Few-Shot Learning for Rooftop Segmentation in Satellite Imagery¶

Open In Colab

This tutorial introduces few-shot learning techniques for semantic segmentation in satellite imagery using high-resolution images from Geneva, Switzerland. We will demonstrate how Prototypical Networks can learn meaningful rooftop representations from only a few labeled examples and generalize to new geographic areas with minimal annotation effort.

Learning Outcomes¶

By the end of the tutorial, you will be able to:

  • Understand the core concepts behind Few-Shot Learning and Few-Shot Semantic Segmentation
  • Work with satellite imagery, geographic splits, and pixel-level segmentation masks
  • Implement Prototypical Networks with episodic training for segmentation tasks
  • Evaluate model performance using metrics such as IoU and interpret FSL model behavior
  • Reflect on policy-relevant applications such as rooftop solar assessment and data-scarce mapping tasks

Prerequisites¶

  • Intermediate Python programming
  • Familiarity with PyTorch
  • Basics of Machine and Deep Learning
  • Understanding of convolutional neural networks

Apart from that, you'll need a strong interest and some motivation to learn a new method 😉

Table of Contents¶

  • Memo
  • Overview
  • Background & Introduction
  • Theoretical Foundation
  • Data Description
  • Methodology
  • Results
  • Optional: Testing
  • Discussion & Limitations
  • ⭐ Challenge ⭐
  • Further Resources
  • References

Memo¶

This notebook offers an in-depth look at few-shot learning (FSL), an advanced deep learning technique that enables models to generalize effectively from only a small number of training examples. FSL becomes particularly valuable when traditional supervised learning is impractical, for instance when annotation is expensive, requires specialized expertise, or when data is simply scarce.

Few-shot learning is especially relevant in the context of public policy, where we often face limited access to comprehensive, well-curated datasets or lack the resources to collect large amounts of training data. Applications span a wide range of policy areas, including disaster classification (Lee et al., 2025), urban planning (Hu et al., 2022) and health policy, where FSL has supported COVID-19 detection (Jadon, 2021) and the diagnosis of rare genetic diseases (Alsentzer et al., 2025).

Beyond classification tasks, FSL also performs well in segmentation problems, for example identifying different types of buildings and vegetation such as forest cover (Puthumanaillam & Verma, 2023), and segmenting remote sensing imagery more broadly (Chen et al., 2022). All of these examples share a common challenge: limited data.

Few-shot learning helps address this constraint, and we encounter similar situations within our own institution. A clear illustration is our Wildlife Trade Monitoring Project (WTMP), where labeled examples of rare, endangered species are extremely scarce and often inconsistent. This emphasizes just how valuable few-shot learning approaches can be.

Today, a wide range of few-shot learning frameworks and architectures are available. Some of the most widely used include model-agnostic meta-learning (Finn, Abbeel & Levine, 2017), prototypical networks (Snell, Swersky & Zemel, 2017), and relation networks (Sung et al., 2018). Over the past few years, the field has continued to evolve, giving rise to more specialized and state-of-the-art techniques, such as SegPPD-FS for semantic segmentation (Ge et al., 2025) and CDCNet for classification tasks (Li et al., 2025).

Because of its importance to public policy and its potential value for our organization, this notebook takes a deep dive into prototypical networks for semantic segmentation of building rooftops. This is an essential step for many downstream tasks, such as assessing rooftop suitability for solar panel installation. By walking through this method, the notebook aims to help our team better understand and apply this powerful technique in our day-to-day work.

Overview and Introduction¶

Dear colleagues,

In this notebook, we will introduce the concept of few-shot learning (FSL). As briefly mentioned earlier, FSL is an advanced deep learning paradigm that enables models to generalize from only a handful of annotated samples. In particular, we will focus on prototypical networks, which learn a metric space where each class is represented by a prototype (usually the mean embedding of its support examples) and then classify query samples based on their distance to these prototypes (more details will be provided below).

This method can be applied to semantic segmentation tasks, allowing us to partition an image into meaningful regions. Few-shot learning has become increasingly relevant with the rise of multi-modal large transformer models such as GPT-4, Claude, and Gemini, which demonstrate remarkable few-shot capabilities by learning from minimal examples across vision and language tasks. The core principles we explore here, such as learning from limited data and generalizing to new examples, underpin many of these modern foundation models, making few-shot learning a fundamental concept in contemporary deep learning. Here we start from the very basics.

To guide you through this approach, we will work with the Geneva Satellite Images dataset. This dataset contains high-resolution (250×250 PNG tiles) satellite imagery of Geneva, Switzerland, along with pixel-level segmentation masks that delineate building rooftops suitable for PV deployment. It was developed by EPFL’s Solar Energy and Building Physics Laboratory and is publicly available on HuggingFace (Castello et al., 2021).

The goal of this notebook is to demonstrate how few-shot learning can be used for rooftop semantic segmentation, with applications such as assessing the suitability of buildings for solar panels.

Throughout the notebook, we will walk you through how to:

  • Load and preprocess satellite imagery
  • Build a few-shot segmentation model
  • Train the model using prototypical networks
  • Evaluate segmentation performance

Please follow the code and try to understand how to apply few-shot learning to segment (suitable) rooftops in new satellite images. You can run each cell in sequence. At every step, we provide clear explanations of what is happening. A pre-recorded video is also available if you prefer a more visual learning format.

If you have any questions about few-shot learning or need help with its implementation, please feel free to reach out. Our contact information is provided in the README.md.

Happy learning!

Your Public Innovation Lab (PILAB)


Theoretical Foundation¶

Key Concepts¶

  • Few-Shot Learning: Few-shot learning aims to train models that can recognize new classes from only a handful of labeled examples by leveraging prior knowledge learned from many related tasks. Unlike traditional deep learning that requires thousands of labeled examples per class, few-shot learning dramatically reduces the annotation burden by learning a generalizable representation that transfers effectively to new tasks with minimal data.
Watch on YouTube

Click on the image above to open YouTube and watch the video (link also available here).

  • Semantic Segmentation: In the context of semantic segmentation, this means learning to assign a class label to every pixel in an image, even when only a few annotated images are available for each class. This pixel-level classification is significantly more challenging than image-level tasks because it requires dense predictions across the entire image, with the model learning precise boundaries between different regions from very limited supervision.

  • Prototypical Networks: Prototypical Networks tackle this by mapping pixels (or patches) into an embedding space where each class is represented by a prototype, typically the mean embedding of its support examples. At inference time, each pixel in a query image is classified by measuring its distance to these class prototypes, enabling accurate pixel-wise segmentation in new scenes with very limited labeled data. Below you find a diagram describing the few-shot rooftop segmentation pipeline used in the notebook:

    1. In each episode, a support satellite image (with its rooftop mask) and a query image go through the same CNN encoder to produce feature maps.
    2. The support mask is applied with Masked Average Pooling (MAP) to turn support features into class prototypes (foreground/roof vs background).
    3. The query feature map is classified pixel-wise by computing similarity to these prototypes, producing the query segmentation.
    4. Training uses losses on both the support prediction ($L_{sup}$) and the query prediction ($L_{que}$).
figure

Figure 1: Overview of the Self-Regularized Prototypical Network Architecture. Adapted from Ding et al. (2022).

Technical Note:
We follow a prototypical-network formulation where support masks are used only to compute foreground/background prototypes via masked average pooling (MAP). The query image is segmented by comparing each query pixel embedding to the prototypes using negative squared Euclidean distance as the similarity score. During training, we optimize only a query-side cross-entropy loss $L_{que}$ against the query ground-truth mask. No auxiliary support loss $L_{sup}$ (self-regularization) branch from the reference diagram is used in our implementation.
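
In formulas (restating the description above, with $f_s, f_q$ the support and query feature maps, $m_s$ the support mask downsampled to the feature resolution, $y_q$ the query labels, and $Q$ the set of query pixels):

$$ p_c = \frac{\sum_{x} \mathbb{1}[m_s(x) = c]\, f_s(x)}{\sum_{x} \mathbb{1}[m_s(x) = c]}, \qquad c \in \{\text{bg}, \text{fg}\} $$

$$ \text{logit}_c(x) = -\lVert f_q(x) - p_c \rVert_2^2, \qquad L_{que} = -\frac{1}{|Q|} \sum_{x \in Q} \log \frac{\exp\big(\text{logit}_{y_q(x)}(x)\big)}{\sum_{c'} \exp\big(\text{logit}_{c'}(x)\big)} $$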

Software Requirements¶

Now let's get started by installing the required packages.

Note:
If you are using Google Colab, these packages should already be installed. If needed, or if you are running this notebook locally, uncomment and run the following cell to install the necessary libraries. Make sure you have a fresh environment set up.

In [ ]:
# uncomment the following line to install required packages
# !pip install -q huggingface_hub torch torchvision scikit-learn matplotlib tqdm
In [5]:
import sys
import os

os.getcwd()
sys.path.append(os.path.abspath(".."))

import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from huggingface_hub import snapshot_download

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
/Users/giocopp/miniconda3/envs/DL-Tutorial/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Using device: cpu
In [115]:
# Reproducibility / seeds

SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

print(f"✅ Seeds set to {SEED}")
✅ Seeds set to 42

Data Description¶

The dataset you will use consists of 1,050 high-resolution satellite images of Geneva's building rooftops, each with a corresponding binary segmentation mask for rooftop segmentation.

Note:
The masks define the rooftop potential for PV installations, not simply the rooftop segments, which makes the task more complex. From here on, we do not give too much importance to this subtle distinction and treat the task as plain "rooftop segmentation": we cannot expect the simple "vanilla" model we will implement to predict such precise masks, and simpler rooftop-only masks would have been beneficial.

Dataset Properties:

  • Image Type: RGB satellite imagery
  • Labels: Binary masks (0 = background, 1 = rooftop)
  • Resolution: 250×250 pixel tiles, resized to 256×256 for training

Data Download¶

The dataset is stored on the 🤗 Hub and each tile comes with:

  • a RGB satellite image, and
  • a binary mask indicating pixels that belong to rooftops suitable for PV installations.

Using huggingface_hub.snapshot_download we download the dataset to the local filesystem.

Note:
To make the data download process go more smoothly, please put your Hugging Face token in a .env file at the root of the directory or set it as an environment variable.

You can either do it by running the following in your terminal:

export HF_TOKEN='your_token_here'

Or by creating a .env file with the following content:

HF_TOKEN='your_token_here'

Or by simply setting the token directly in the code cell below (not recommended for security reasons):

HF_TOKEN = 'your_token_here'
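
If you keep the token in a .env file, the following minimal sketch shows one way to load it into the environment (this assumes the python-dotenv package, which is not part of the install list above; skip it if you already export HF_TOKEN in your shell):

In [ ]:
# Optional (minimal sketch): load HF_TOKEN from a .env file at the project root.
# Assumes `pip install python-dotenv`; uncomment to use.
# from dotenv import load_dotenv
# load_dotenv()
# print("HF_TOKEN found:", os.getenv("HF_TOKEN") is not None)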
In [ ]:
## Download dataset once (cached afterward)

def download_geneva_dataset():
    token = os.getenv("HF_TOKEN")  # set your HF_TOKEN env variable if you have one, else None

    dataset_root = snapshot_download(
        repo_id="raphaelattias/overfitteam-geneva-satellite-images",
        repo_type="dataset",
        token=token,
        resume_download=True,
    )

    print(f"Dataset downloaded to: {dataset_root}")
    return dataset_root


dataset_root = download_geneva_dataset()
/Users/giocopp/miniconda3/envs/DL-Tutorial/lib/python3.11/site-packages/huggingface_hub/file_download.py:1142: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Fetching 2111 files: 100%|██████████| 2111/2111 [00:00<00:00, 9967.64it/s]
Dataset downloaded to: /Users/giocopp/.cache/huggingface/hub/datasets--raphaelattias--overfitteam-geneva-satellite-images/snapshots/3f90d759384e4cd38276290521b6f6b03ddfcf87

Data Preprocessing¶

To work with PyTorch, you will define a Dataset class that returns (image, mask) pairs.

Data Transformation¶

  • In a first step, you resize the images to a fixed resolution of 256 × 256, apply a mild contrast jitter, convert them to PyTorch tensors, and normalise with the ImageNet mean and standard deviation. This is compatible with the pretrained ResNet18 backbone we will use later.
  • Then, you resize the masks using nearest-neighbour interpolation to avoid mixing labels at boundaries, convert them to tensors in [0, 1], and binarise them (threshold 0.5) so that every pixel is either background (class 0) or rooftop (class 1).

The GenevaRooftopDataset class wraps these steps and reads from a structured folder layout with images/ and labels/ subdirectories.

Train-Test: Geographical Split¶

Subsequently, you can define the training set to be images from the northern (1301_11) and southern (1301_31) parts of Geneva; the grid ID is extracted from the image filename. As the test set, you can use images from the central (1301_13) part of the city. We expect the periphery of the city to show a greater variety of rooftop types than the city center, and the two areas to be structurally different to a certain degree.

Why we ignore the dataset’s folder splits: For policy use-cases, we often care about spatial generalization (whether a model trained in some areas works in new areas). Therefore, we deliberately construct our own split by holding out entire grid IDs as the test set. This prevents information leakage across space (no tiles from the same grid appearing in both train and test), even if the original folder split is different.

figure
In [ ]:
## Image transformation and dataset class

IMAGE_SIZE = 256  # resize tiles to this

img_transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ColorJitter(contrast=0.4),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

mask_transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=Image.NEAREST),
        transforms.ToTensor(),  # gives float [0,1] for grayscale
    ]
)


def get_grid_id_from_filename(fname):
    parts = fname.split("_")
    # 3rd and 4th underscore-separated parts give the grid ID
    return f"{parts[2]}_{parts[3]}"


class GenevaRooftopDataset(Dataset):
    """
    Dataset filtered by geographic grid IDs.

    Can read from multiple splits (train/val/test) at once.
    """

    def __init__(self, root, splits=["train", "val", "test"], category="all", grid_ids=None):
        super().__init__()
        self.root = root
        self.category = category
        self.grid_ids = grid_ids
        self.files_info = []  # list of tuples (split, filename)

        # Collect files from all specified splits
        for split in splits:
            image_dir = os.path.join(root, split, "images", category)
            #label_dir = os.path.join(root, split, "labels", category)

            all_files = sorted([f for f in os.listdir(image_dir) if f.endswith(".png")])

            # Filter by grid_ids if provided
            if grid_ids is not None:
                all_files = [f for f in all_files if get_grid_id_from_filename(f) in grid_ids]

            # Store split info for loading
            self.files_info.extend([(split, f) for f in all_files])

    def __len__(self):
        return len(self.files_info)

    def __getitem__(self, idx):
        split, fname = self.files_info[idx]
        img_path = os.path.join(self.root, split, "images", self.category, fname)
        mask_name = fname.replace(".png", "_label.png")
        mask_path = os.path.join(self.root, split, "labels", self.category, mask_name)

        img = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        img = img_transform(img)
        mask = mask_transform(mask)
        mask = (mask > 0.5).float()

        return img, mask
In [8]:
## Train-Test split

train_grids = ["1301_11", "1301_31"]
test_grids = ["1301_13"]

# Train dataset reads from all three folders
train_base = GenevaRooftopDataset(dataset_root, splits=["train", "val", "test"], grid_ids=train_grids)

# Test dataset can read from just one folder or multiple if needed
test_base = GenevaRooftopDataset(dataset_root, splits=["train", "val", "test"], grid_ids=test_grids)

print(f"Train samples: {len(train_base)}, Test samples: {len(test_base)}")
Train samples: 423, Test samples: 102
In [ ]:
## Show examples

def plot_overlay(image, mask, alpha=0.4):
    """
    Overlay a binary mask on top of an RGB image.

    image: [H,W,3] float in [0,1]
    mask:  [H,W]   {0,1} (will be resized if needed)
    """
    img = image

    # Ensure mask is 2D
    if mask.ndim == 3:
        mask = np.squeeze(mask)

    h, w = img.shape[:2]

    # If mask and image have different spatial size, resize mask
    if mask.shape != (h, w):
        mask_pil = Image.fromarray(mask.astype(np.uint8))
        mask_pil = mask_pil.resize((w, h), resample=Image.NEAREST)
        mask = np.array(mask_pil)

    # Normalize mask to {0,1} just in case
    if mask.max() > 1:
        mask = (mask > 0).astype(np.float32)

    # Create RGB mask
    mask_rgb = np.zeros_like(img)
    mask_rgb[..., 0] = mask  # red channel = rooftop

    overlay = (1 - alpha) * img + alpha * mask_rgb
    return overlay


def show_sample(dataset, idx=None):
    if idx is None:
        idx = random.randint(0, len(dataset) - 1)
    img, mask = dataset[idx]  # img [3,H,W], mask [1,H,W]

    # Undo normalisation for plotting
    img_np = img.permute(1, 2, 0).numpy()
    img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    img_np = np.clip(img_np, 0, 1)

    # mask: [1, H, W] -> [H, W]
    mask_np = mask.squeeze(0).numpy()

    overlay_np = plot_overlay(img_np, mask_np)

    plt.figure(figsize=(12, 4))

    plt.subplot(1, 3, 1)
    plt.imshow(img_np)
    plt.title("Image")
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.imshow(overlay_np)
    plt.title("Image with Overlay")
    plt.axis("off")

    plt.subplot(1, 3, 3)
    plt.imshow(mask_np, cmap="gray")
    plt.title("Mask (PV potential)")
    plt.axis("off")

    plt.tight_layout()
    plt.show()
In [10]:
show_sample(train_base)

Summary: Class Distribution¶

You can then compute foreground (FG; rooftop) vs background (BG) pixel fractions across the geographic train and test splits. Rooftops represent ~16–17% of all pixels, while the remaining ~83–84% is background.

Interesting observations:

  • Splits maintain similar FG/BG ratios (consistent dataset structure)

  • Rooftop pixels are notably less frequent (this creates a class imbalance typical for semantic segmentation of small objects)

  • Some images contain large rooftop structures, which explains the slightly higher foreground percentage compared to typical urban aerial datasets

In [ ]:
def compute_basic_stats_geographic(dataset, name="dataset"):
    """
    Compute rooftop pixel fraction statistics for a geographic dataset.

    Args:
        dataset: GenevaRooftopDataset instance
        name: str, name of the dataset for printing purposes

    Returns:
        dict with rooftop fraction statistics

    """
    if len(dataset) == 0:
        raise ValueError(f"Dataset {name} is empty")

    rooftop_fracs = []

    for idx in range(len(dataset)):
        _, mask = dataset[idx]
        rooftop_frac = mask.mean().item()
        rooftop_fracs.append(rooftop_frac)

    rooftop_fracs = np.array(rooftop_fracs)

    print()
    print(f"Rooftop pixel fraction stats for {name} (n={len(dataset)}):")
    print("-" * 20)
    print(f"  mean   : {rooftop_fracs.mean():.3f}")
    print(f"  median : {np.median(rooftop_fracs):.3f}")
    print(f"  min    : {rooftop_fracs.min():.3f}")
    print(f"  max    : {rooftop_fracs.max():.3f}")
    print("-" * 45)

    return {"rooftop_fracs": rooftop_fracs}
In [12]:
train_stats = compute_basic_stats_geographic(train_base, name="train")
test_stats = compute_basic_stats_geographic(test_base, name="test")
Rooftop pixel fraction stats for train (n=423):
--------------------
  mean   : 0.162
  median : 0.131
  min    : 0.001
  max    : 0.972
---------------------------------------------

Rooftop pixel fraction stats for test (n=102):
--------------------
  mean   : 0.170
  median : 0.143
  min    : 0.006
  max    : 0.546
---------------------------------------------

Methodology¶

Few-shot episodes: support and query¶

Few-shot learning is typically implemented with episodic training. An episode is one small training problem that mimics how we will use the model at test time: we see some labelled examples (support) and we must make predictions for new data (query).

In this setting, each episode contains:

  • a support image with its mask (img_s, mask_s), and
  • a query image with its mask (img_q, mask_q).

In the following class, you create an EpisodeDataset that wraps the base training dataset. For every episode, it randomly picks two different tiles from the training set:

  • one becomes the support pair,
  • the other becomes the query pair.

During training, the model:

  1. looks at the support image + mask to learn what rooftops and background look like, and
  2. tries to correctly segment the query image, using that information.

The parameter episodes_per_epoch defines how many such support–query pairs (episodes) we use in one training epoch. A DataLoader iterates over this dataset and gives us one episode at a time for training.

In [ ]:
class EpisodeDataset(Dataset):
    """
    Yields (support_imgs [K,3,H,W], support_masks [K,1,H,W], query_img, query_mask).

    Args:
        base_dataset: GenevaRooftopDataset instance
        episodes_per_epoch: int, number of episodes per epoch
        K: int, number of support examples per episode

    Returns:
        support_imgs: Tensor [K, 3, H, W]
        support_masks: Tensor [K, 1, H, W]
        query_img: Tensor [3, H, W]
        query_mask: Tensor [1, H, W]

    """

    def __init__(self, base_dataset, episodes_per_epoch=2000, K=1):
        self.base = base_dataset
        self.episodes_per_epoch = episodes_per_epoch
        self.n = len(base_dataset)
        self.K = K

    def __len__(self):
        return self.episodes_per_epoch

    def __getitem__(self, idx):
        # sample K support + 1 query index
        indices = random.sample(range(self.n), self.K + 1)
        *support_idx, query_idx = indices  # Python unpacking

        # get support items
        support = [self.base[i] for i in support_idx]
        imgs_s, masks_s = zip(*support, strict=False)  # tuples of tensors

        # get query item
        img_q, mask_q = self.base[query_idx]

        # stack support tensors
        imgs_s = torch.stack(imgs_s, dim=0)  # [K, 3, H, W]
        masks_s = torch.stack(masks_s, dim=0)  # [K, 1, H, W]

        return imgs_s, masks_s, img_q, mask_q
In [14]:
episodes_per_epoch = 2000
episode_dataset = EpisodeDataset(train_base, episodes_per_epoch=episodes_per_epoch)
episode_loader = DataLoader(episode_dataset, batch_size=1, shuffle=True)

Encoder backbone: ResNet18 feature extractor¶

You can now define the feature extractor that underpins the prototypical network. The Encoder class:

  • uses a pretrained ResNet18 from torchvision,
  • keeps the convolutional layers up to layer3 (which downsamples the input by a factor of 16) and fuses the upsampled layer2 features with the layer3 features, and
  • adds a 1×1 convolution to project the concatenated features into a configurable embedding dimension (here 256 channels).

Given an input image of shape [3, H, W], the encoder outputs a feature map [C, H', W']. These feature maps are the basis for computing class prototypes (rooftop vs background) and for classifying each pixel at test time.

Note:
ResNet18 is lightweight and pretrained on ImageNet. We expect a model pretrained on geospatial/satellite imagery to perform better as a feature extractor; exploring this could be a good extension.
Question: How does changing the backbone architecture alter the performance?

In [15]:
class Encoder(nn.Module):
    def __init__(self, out_channels=256, pretrained=True):
        super().__init__()
        backbone = models.resnet18(pretrained=pretrained)
        self.stem = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
        )
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3

        self.proj = nn.Conv2d(128 + 256, out_channels, kernel_size=1)  # 128 from layer2, 256 from layer3

    def forward(self, x):
        x = self.stem(x)
        f1 = self.layer1(x)
        f2 = self.layer2(f1)  # [B,128,H/8,W/8]
        f3 = self.layer3(f2)  # [B,256,H/16,W/16]

        f2_up = F.interpolate(f2, size=f3.shape[-2:], mode="bilinear", align_corners=False)
        f = torch.cat([f2_up, f3], dim=1)  # [B,384,H',W']
        f = self.proj(f)
        return f
In [16]:
encoder = Encoder(out_channels=256).to(device)
print(encoder)
/Users/giocopp/miniconda3/envs/DL-Tutorial/lib/python3.11/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/Users/giocopp/miniconda3/envs/DL-Tutorial/lib/python3.11/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Encoder(
  (stem): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (proj): Conv2d(384, 256, kernel_size=(1, 1), stride=(1, 1))
)

Comment:
This ResNet backbone model takes 3-channel images and turns them into a compact feature map. The stem (Conv → BatchNorm → ReLU → MaxPool) quickly reduces spatial resolution while extracting low-level features, and the subsequent layers are stacks of residual BasicBlocks that learn increasingly complex patterns. At the end, a 1×1 convolutional projection compresses the combined feature channels (384 → 256) into the final embedding used by the rest of the model.
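
As a quick sanity check, the minimal sketch below (an addition, using only the encoder and constants defined above) passes a dummy input through the model to confirm the output resolution:

In [ ]:
# Sanity check (minimal sketch): a 256x256 input should produce a 256-channel
# feature map at 1/16 of the input resolution, i.e. shape [1, 256, 16, 16].
with torch.no_grad():
    dummy = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE).to(device)
    feat = encoder(dummy)
print("Feature map shape:", tuple(feat.shape))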

Prototypical segmentation: prototypes and query classification¶

The core idea of prototypical networks is to represent each class by a prototype vector in feature space and to classify new examples based on their distance to these prototypes. In the following code, you will use the (squared) Euclidean distance to measure this similarity, as it is straightforward.

For segmentation this is applied at the pixel level. In this case you have only one object type (rooftops vs background), so you use a 2-class prototypical network.

  • From the support images and their masks, you compute:

    • a foreground prototype (rooftops),
    • and a background prototype, by performing MAP over the encoder feature maps.
  • For a query image, you run the encoder once to obtain features and then:

    • compute the Euclidean distance from each pixel’s feature vector to each prototype,
    • convert distances into logits (negative distances),
    • and reshape to a [1, 2, H', W'] logit map over classes 0 (background) and 1 (rooftop).

These two functions, compute_prototypes and classify_query, implement the prototypical segmentation logic used in both training and evaluation.

In [ ]:
def compute_prototypes(feat_support, mask_support):
    """
    Compute background/foreground prototypes from multiple support images.

    Args:
        feat_support: [K, C, H', W']   (K support images)
        mask_support: [K, 1, H,  W]    (binary masks)

    Returns:
        prototypes [2, C] (0=background, 1=foreground)

    """
    # Downsample mask to feature resolution
    mask_small = F.interpolate(
        mask_support, size=feat_support.shape[2:], mode="nearest"
    )  # Downsamples masks to feature size via nearest neighbor
    mask_fg = (mask_small > 0.5).float()  # [K,1,H',W']
    mask_bg = 1.0 - mask_fg  # [K,1,H',W']

    K, C, Hf, Wf = feat_support.shape

    # Flatten across batch and spatial dims: [K,C,H',W'] -> [C, K*H'*W']
    fs = feat_support.permute(1, 0, 2, 3).contiguous().view(C, -1)  # [C, K*H'*W']
    fg_w = mask_fg.view(1, -1)  # [1, K*H'*W']
    bg_w = mask_bg.view(1, -1)

    eps = 1e-6

    # MAP (masked weighted averages) for foreground
    fg_proto = (fs * fg_w).sum(dim=1) / (fg_w.sum(dim=1) + eps)  # [C]
    # MAP (masked weighted averages) for background
    bg_proto = (fs * bg_w).sum(dim=1) / (bg_w.sum(dim=1) + eps)  # [C]

    prototypes = torch.stack([bg_proto, fg_proto], dim=0)  # [2,C]
    return prototypes


def classify_query(feat_query, prototypes):
    """
    Classify query pixels by distance to prototypes.

    Args:
        feat_query: [1, C, H', W']
        prototypes: [2, C]

    Returns:
        logits [1, 2, H', W']

    """
    B, C, Hq, Wq = feat_query.shape

    # [1,C,H',W'] -> [H'*W', C]
    fq = feat_query.view(C, -1).t()  # [H'*W', C]

    # [2, C]
    protos = prototypes  # [2,C]

    # Compute squared Euclidean distance from each pixel to each prototype
    # torch.cdist expects [B, N, D], so add batch dim
    # fq_batch: [1, H'*W', C], protos_batch: [1, 2, C]
    dists = torch.cdist(fq.unsqueeze(0), protos.unsqueeze(0))  # [1, H'*W', 2]
    dists = dists.squeeze(0)  # [H'*W', 2]
    dists = dists**2

    # Convert distances to similarity logits: negative distance
    logits_flat = -dists  # [H'*W', 2]
    logits = logits_flat.t().view(1, 2, Hq, Wq)  # [1,2,H',W']

    return logits

Evaluation metric: Intersection over Union (IoU)¶

To evaluate segmentation performance you will use Intersection over Union (IoU):

$$ \text{IoU} = \frac{\text{intersection of predicted and true mask}}{\text{union of predicted and true mask}}. $$

Given the predicted logits [1, 2, H, W] and the ground-truth binary mask [1, 1, H, W], you:

  1. take the argmax over classes to obtain a predicted mask,
  2. compute intersection and union between predicted and true rooftop pixels, and
  3. return IoU as a scalar.

This metric will be used to summarise how well your few-shot model segments rooftops on held-out test images.
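
For example, if the model predicts 100 rooftop pixels, the ground-truth mask contains 80, and 60 of them overlap, then IoU = 60 / (100 + 80 - 60) = 0.5.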

In [ ]:
def iou_from_logits(logits, target_mask, eps=1e-6):
    """
    Compute IoU given model logits and target mask.

    Args:
        logits: [1, 2, H, W]
        target_mask: [1, H, W] (binary mask)
        eps: small float to avoid division by zero

    Returns:
        iou: float

    """
    # predicted class (0 or 1)
    pred = logits.argmax(dim=1, keepdim=True).float()  # [1,1,H,W]
    target = (target_mask > 0.5).float()

    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection

    iou = (intersection + eps) / (union + eps)
    return iou.item()

Episodic meta-training¶

You can now train the encoder using episodic supervision. For each episode in the training set:

  1. You will obtain a support image and mask and a query image and mask.
  2. You will encode both images using the shared encoder.
  3. From the support features and mask, you will compute foreground and background prototypes.
  4. You will classify each pixel in the query feature map by its distance to these prototypes.
  5. You will build the query’s ground-truth labels at the feature resolution and compute a cross-entropy loss between predicted logits and true classes.

This process pushes the encoder to learn a feature space in which simple prototype-based classification works well across many different episodes. After a few epochs, the encoder can be reused for 1-shot or K-shot segmentation on unseen test images.

In [ ]:
optimizer = torch.optim.Adam(encoder.parameters(), lr=3e-4, weight_decay=1e-4)

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

def meta_train(num_epochs=7):
    """
    Meta-train the encoder using episodic training.

    Args:
        num_epochs: int, number of epochs to train

    Returns:
        loss_history: list of average episode losses per epoch

    """
    loss_history = []
    for epoch_idx in range(1, num_epochs + 1):
        encoder.train()
        epoch_loss_sum = 0.0

        for support_img, support_mask, query_img, query_mask in episode_loader:
            support_img_bchw = support_img.squeeze(0).to(device)    # [1,K,3,H,W] -> [K,3,H,W]
            support_mask_b1hw = support_mask.squeeze(0).to(device)  # [1,K,1,H,W] -> [K,1,H,W]

            query_img_bchw = query_img.to(device)                   # [B,3,H,W] with B=1
            query_mask_b1hw = query_mask.to(device)                 # [B,1,H,W] with B=1

            optimizer.zero_grad(set_to_none=True)

            support_feats = encoder(support_img_bchw)
            query_feats = encoder(query_img_bchw)

            prototypes_2c = compute_prototypes(support_feats, support_mask_b1hw)
            query_logits_small = classify_query(query_feats, prototypes_2c)

            query_logits = F.interpolate(
                query_logits_small,
                size=query_mask_b1hw.shape[-2:],
                mode="bilinear",
                align_corners=False,
            )

            query_target = query_mask_b1hw.long().squeeze(1)  # [B,1,H,W] -> [B,H,W]
            episode_loss = F.cross_entropy(query_logits, query_target)

            episode_loss.backward()
            optimizer.step()

            epoch_loss_sum += float(episode_loss.item())

        avg_epoch_loss = epoch_loss_sum / len(episode_loader)
        print(f"Epoch {epoch_idx}/{num_epochs} | avg episode loss: {avg_epoch_loss:.4f}")
        loss_history.append(avg_epoch_loss)

        scheduler.step()

    return loss_history
In [20]:
loss_history = meta_train(num_epochs=7)
Epoch 1/7 | avg episode loss: 0.2512
Epoch 2/7 | avg episode loss: 0.1196
Epoch 3/7 | avg episode loss: 0.0999
Epoch 4/7 | avg episode loss: 0.1023
Epoch 5/7 | avg episode loss: 0.0853
Epoch 6/7 | avg episode loss: 0.0689
Epoch 7/7 | avg episode loss: 0.0662
In [21]:
## Plot of meta learner performance

loss_history = np.array(loss_history)
epochs = range(1, len(loss_history) + 1)

plt.figure()
plt.plot(epochs, loss_history, marker="o")
plt.xlabel("Epoch")
plt.ylabel("Average episode loss")
plt.title("Meta-training loss history")
plt.grid(True)
plt.show()

Interpretation of the average episode loss
The “avg episode loss” at each epoch is the average cross-entropy error over all support–query tasks in that epoch; the fact that it steadily drops from ~0.25 to ~0.07 means that the encoder is learning a feature space where prototype-based segmentation works increasingly well. The loss is still decreasing slightly, meaning we could train for longer to obtain marginally better results. For the scope of this tutorial, we stop at epoch 7.

Recap what episodes are
An episode is one small training problem that mimics how we will use the model at test time. For one episode, we:

  • Take (support_image, support_mask, query_image, query_mask).
  • Encode support → compute prototypes (foreground/background).
  • Encode query → classify each pixel using the prototypes.
  • Compare predicted class vs true class at each pixel of the query.
  • Compute a cross-entropy loss over all those pixels.
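
To make this concrete, the minimal sketch below (an addition using only functions and objects defined above) runs a single episode by hand, which can be handy for debugging:

In [ ]:
# Minimal sketch: run one episode manually, outside the training loop
imgs_s, masks_s, img_q, mask_q = next(iter(episode_loader))

with torch.no_grad():
    feat_s = encoder(imgs_s.squeeze(0).to(device))    # [K, C, H', W']
    feat_q = encoder(img_q.to(device))                # [1, C, H', W']
    protos = compute_prototypes(feat_s, masks_s.squeeze(0).to(device))
    logits = F.interpolate(
        classify_query(feat_q, protos),
        size=mask_q.shape[-2:], mode="bilinear", align_corners=False,
    )
    loss = F.cross_entropy(logits, mask_q.to(device).long().squeeze(1))

print(f"Single-episode loss: {loss.item():.4f} | IoU: {iou_from_logits(logits.cpu(), mask_q):.3f}")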

Few-shot inference: 1-shot and K-shot helpers¶

After meta-training comes few-shot inference (meta-testing): you want to use the trained encoder for few-shot segmentation on new data.

For that, you define a function where:

  • k_shot_predict takes:
    • K support images and masks, and
    • a query image, and returns the full-resolution logits [1, 2, H, W] for the query.
  • Internally, it encodes supports and query, computes prototypes using all K supports, classifies the query at the feature resolution, and upsamples the logits back to the original image size.

For convenience, one_shot_predict wraps this function for the special case K = 1. These utilities are used both for quantitative evaluation and for visualising qualitative results.

In [ ]:
def k_shot_predict(encoder, support_imgs, support_masks, query_img):
    """
    K-shot segmentation for a query image given K support images+masks.

    Args:
        encoder: Encoder model
        support_imgs:  [K, 3, H, W]
        support_masks: [K, 1, H, W]
        query_img:     [3, H, W]

    Returns:
        logits [1, 2, H, W]

    """
    encoder.eval()
    with torch.no_grad():
        support_imgs = support_imgs.to(device)
        support_masks = support_masks.to(device)
        query_img = query_img.to(device).unsqueeze(0)

        # Pass through encoder
        feat_s = encoder(support_imgs)
        feat_q = encoder(query_img)

        # Compute prototypes
        prototypes = compute_prototypes(feat_s, support_masks)

        # Classify query pixels
        logits_small = classify_query(feat_q, prototypes)
        logits = F.interpolate(
            logits_small,
            size=(query_img.shape[2], query_img.shape[3]),
            mode="bilinear",
            align_corners=False,
        )

    return logits.cpu()


def one_shot_predict(encoder, support_img, support_mask, query_img):
    """
    1-shot helper that wraps single support into K=1 form.

    Args:
        encoder: Encoder model
        support_img:  [3, H, W]
        support_mask: [1, H, W]
        query_img:    [3, H, W]

    Returns:
        logits [1, 2, H, W]

    """
    support_imgs = support_img.unsqueeze(0)
    support_masks = support_mask.unsqueeze(0)
    return k_shot_predict(encoder, support_imgs, support_masks, query_img)

Results¶

Quantitative evaluation: Zero-, One-, and K-shot IoU¶

To assess the model, you evaluate K-shot IoU on the test split:

  • For each test image (query), you randomly sample K distinct support images from the training split.
  • You run k_shot_predict to obtain predicted logits and compute IoU with respect to the true rooftop mask.
  • You repeat this process for multiple random test queries and report the mean and standard deviation of IoU.

By varying K you can study how performance improves as you provide more labelled examples to the model.

In [ ]:
def evaluate_kshot_iou(
    encoder,
    test_dataset,
    train_dataset=None,
    K=5,
    num_samples=None,
    seed=42,
    replace_test=False,  # when num_samples is set: sample with replacement?
):
    """
    Evaluate K-shot IoU on test_dataset.

    If K == 0: zero-shot evaluation (no support).
    If K  > 0: K-shot evaluation using train_dataset as support pool.

    Behavior for choosing queries:
      - num_samples is None  -> evaluate ALL test samples exactly once (no duplicates)
      - num_samples is int   -> randomly sample num_samples from test_dataset
                               (replace_test controls with/without replacement)
    """
    encoder.eval()
    rng = np.random.default_rng(seed)

    n_test = len(test_dataset)

    # --- Validate inputs ---
    if K > 0:
        if train_dataset is None:
            raise ValueError("train_dataset must be provided when K > 0")
        if K > len(train_dataset):
            raise ValueError(f"K={K} is larger than train_dataset size={len(train_dataset)}")

    # --- Choose which test indices to evaluate ---
    if num_samples is None:
        indices = np.arange(n_test)  # each test example exactly once
    elif (not replace_test) and num_samples <= n_test:
        indices = rng.choice(n_test, size=num_samples, replace=False)
    else:
        indices = rng.integers(0, n_test, size=num_samples)

    ious = []
    with torch.no_grad():
        for ti in indices:
            img_q, mask_q = test_dataset[int(ti)]

            if K == 0:
                logits = zero_shot_predict(encoder, img_q)  # expected [1,2,H,W]
            else:
                support_indices = rng.choice(len(train_dataset), size=K, replace=False)
                support_imgs, support_masks = [], []

                for si in support_indices:
                    img_s, mask_s = train_dataset[int(si)]
                    support_imgs.append(img_s)
                    support_masks.append(mask_s)

                support_imgs = torch.stack(support_imgs, dim=0)
                support_masks = torch.stack(support_masks, dim=0)

                logits = k_shot_predict(encoder, support_imgs, support_masks, img_q)

            iou = iou_from_logits(logits, mask_q.unsqueeze(0))
            # make sure it's a plain float for numpy stats
            if torch.is_tensor(iou):
                iou = float(iou.detach().cpu().item())
            ious.append(iou)

    ious = np.asarray(ious, dtype=float)

    tag = "0-shot" if K == 0 else f"{K}-shot"
    n_eval = len(indices)
    n_unique = len(np.unique(indices))

    print(
        f"{tag} mean IoU over {n_eval} test samples (unique={n_unique}, test_size={n_test}): "
        f"{ious.mean():.3f} (±{ious.std():.3f})"
    )
    return ious

Visualising a 1-shot prediction¶

In [58]:
def tensor_to_rgb(img_tensor):
    """Undo normalisation and convert [3,H,W] tensor to [H,W,3] RGB numpy."""
    img_np = img_tensor.detach().cpu().permute(1, 2, 0).numpy()
    img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    img_np = np.clip(img_np, 0, 1)
    return img_np


def visualise_kshot_example(encoder, train_dataset, test_dataset, K=5):
    encoder.eval()
    rng = np.random.default_rng()

    # pick query from test
    ti = rng.integers(0, len(test_dataset))
    img_q, mask_q = test_dataset[ti]

    # pick K supports from train
    support_indices = rng.choice(len(train_dataset), size=K, replace=False)
    support_imgs, support_masks = [], []
    for si in support_indices:
        img_s, mask_s = train_dataset[si]
        support_imgs.append(img_s)
        support_masks.append(mask_s)
    support_imgs = torch.stack(support_imgs, dim=0)  # [K,3,H,W]
    support_masks = torch.stack(support_masks, dim=0)  # [K,1,H,W]

    # prediction
    logits = k_shot_predict(encoder, support_imgs, support_masks, img_q)
    pred_mask = logits.argmax(dim=1, keepdim=True).float().squeeze().numpy()

    img_q_np = tensor_to_rgb(img_q)
    mask_q_np = mask_q.squeeze(0).numpy()

    iou = iou_from_logits(logits, mask_q.unsqueeze(0))

    # Plot
    cols = max(K, 3)
    plt.figure(figsize=(4 * cols, 8))

    # first row: support images
    for i in range(K):
        plt.subplot(2, cols, i + 1)
        plt.imshow(tensor_to_rgb(support_imgs[i]))
        plt.title(f"Support {i+1} image")
        plt.axis("off")

    # second row: support masks
    for i in range(K):
        plt.subplot(2, cols, cols + i + 1)
        plt.imshow(support_masks[i].squeeze(0).numpy(), cmap="gray")
        plt.title(f"Support {i+1} mask")
        plt.axis("off")

    # query image, ground-truth mask and prediction shown in a separate figure
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 3, 1)
    plt.imshow(img_q_np)
    plt.title("Query image")
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.imshow(mask_q_np, cmap="gray")
    plt.title("Ground Truth mask (Query)")
    plt.axis("off")

    plt.subplot(1, 3, 3)
    plt.imshow(pred_mask, cmap="gray")
    plt.title(f"Predicted mask ({K}-shot)\nIoU: {iou:.3f}")
    plt.axis("off")

    plt.tight_layout()
    plt.show()
In [59]:
visualise_kshot_example(encoder, train_base, test_base, K=1)

Visualising a 5-shot prediction¶

Finally, you can visualise a K-shot episode (e.g. 5-shot) to complement the numerical results. For that:

  1. Select a random query image from the test set.
  2. Select K random support images from the training set.
  3. Run k_shot_predict to produce a predicted mask.
  4. Plot:
    • the query image,
    • the ground-truth query mask,
    • and the K-shot predicted mask.

Qualitatively comparing this to the 1-shot visualisation helps you see where additional support images improve the segmentation, typically in challenging cases such as partially occluded rooftops or unusual roof materials.

In [72]:
visualise_kshot_example(encoder, train_base, test_base, K=5)

We can also see what a zero-shot prediction looks like:

Zero-Shot Prediction¶

Unlike few-shot learning where we have support examples, zero-shot prediction attempts to segment the query image without any labeled examples. Since we don't have ground-truth prototypes, we use a simple heuristic:

  • Background prototype: computed as the global average of all query features
  • Foreground prototype: computed by keeping only the feature channels whose average activation lies above the 90th percentile (top 10%)

Visualizing the results of this naive approach and comparing them with the few-shot results demonstrates the effect of learning from even a few examples compared to having none at all.

In [ ]:
def zero_shot_predict(encoder, query_img):
    """
    Zero-shot segmentation for a query image without any support examples.

    This uses a simple heuristic: create prototypes from the query features themselves.
    - Background prototype: global average of all query features
    - Foreground prototype: use high-activation features (top percentile)

    Args:
        encoder: trained Encoder model
        query_img: [3, H, W]

    Returns:
        logits [1, 2, H, W]

    """
    encoder.eval()
    with torch.no_grad():
        query_img = query_img.to(device).unsqueeze(0)  # [1, 3, H, W]

        # Extract features
        feat_q = encoder(query_img)  # [1, C, H', W']

        # Create default prototypes from query features
        # C = feat_q.shape[1]

        # Background prototype: global average
        proto_bg = feat_q.mean(dim=[2, 3])  # [1, C]

        # Foreground prototype: use features with high activation
        # Use top 10% of spatially-averaged channel activations
        channel_activations = feat_q.mean(dim=[2, 3])  # [1, C]
        threshold = torch.quantile(channel_activations, 0.9)
        proto_fg = feat_q.clone()
        proto_fg[:, channel_activations.squeeze() < threshold] = 0
        proto_fg = proto_fg.mean(dim=[2, 3])  # [1, C]

        # Stack prototypes: [2, C]
        prototypes = torch.cat([proto_bg, proto_fg], dim=0)  # [2, C]

        # Classify query pixels using the same function as k-shot
        logits_small = classify_query(feat_q, prototypes)
        logits = F.interpolate(
            logits_small,
            size=(query_img.shape[2], query_img.shape[3]),
            mode="bilinear",
            align_corners=False,
        )

    return logits.cpu()
In [84]:
# Test zero-shot prediction on a random test image
idx = random.randrange(len(test_base))
img_q, mask_q = test_base[idx]

# Run zero-shot prediction
logits_0shot = zero_shot_predict(encoder, img_q)

# Get predicted mask
pred_mask = logits_0shot.argmax(dim=1).squeeze(0).numpy()

# Compute IoU
iou = iou_from_logits(logits_0shot, mask_q.unsqueeze(0))

# Convert to numpy for visualization
img_q_np = tensor_to_rgb(img_q)
mask_q_np = mask_q.squeeze(0).numpy()

# Visualize
plt.figure(figsize=(12, 4))

plt.subplot(1, 3, 1)
plt.imshow(img_q_np)
plt.title("Query image")
plt.axis("off")

plt.subplot(1, 3, 2)
plt.imshow(mask_q_np, cmap="gray")
plt.title("Ground Truth mask")
plt.axis("off")

plt.subplot(1, 3, 3)
plt.imshow(pred_mask, cmap="gray")
plt.title(f"Predicted mask (0-shot)\nIoU: {iou:.3f}")
plt.axis("off")

plt.tight_layout()
plt.show()

Comparing performance¶

Using the evaluation function above, you can now compare:

  • 0-shot IoU
  • 1-shot IoU
  • 5-shot IoU
  • 10-shot IoU
  • 20-shot IoU

You would expect performance to improve with more support examples, as the prototypes become more representative of the diversity of rooftop appearances. This experiment illustrates a key trade-off in few-shot learning between annotation cost (how many labelled examples we need) and model performance.

In [88]:
ious_0shot = evaluate_kshot_iou(encoder, test_base, K=0)
ious_1shot = evaluate_kshot_iou(encoder, test_base, train_base, K=1)
ious_5shot = evaluate_kshot_iou(encoder, test_base, train_base, K=5)
ious_10shot = evaluate_kshot_iou(encoder, test_base, train_base, K=10)
ious_20shot = evaluate_kshot_iou(encoder, test_base, train_base, K=20)
0-shot mean IoU over 102 test samples (unique=102, test_size=102): 0.334 (±0.220)
1-shot mean IoU over 102 test samples (unique=102, test_size=102): 0.459 (±0.204)
5-shot mean IoU over 102 test samples (unique=102, test_size=102): 0.485 (±0.198)
10-shot mean IoU over 102 test samples (unique=102, test_size=102): 0.484 (±0.193)
20-shot mean IoU over 102 test samples (unique=102, test_size=102): 0.483 (±0.190)
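
To visualise the trade-off between annotation effort and performance discussed above, a minimal sketch (an addition, using the ious_* arrays returned by the previous cell) is:

In [ ]:
# Minimal sketch: mean IoU (with std as error bars) vs number of support images K
ks = [0, 1, 5, 10, 20]
results = [ious_0shot, ious_1shot, ious_5shot, ious_10shot, ious_20shot]

means = [r.mean() for r in results]
stds = [r.std() for r in results]

plt.figure()
plt.errorbar(ks, means, yerr=stds, marker="o", capsize=3)
plt.xlabel("Number of support examples (K)")
plt.ylabel("Mean IoU on test split")
plt.title("K-shot IoU vs annotation effort")
plt.grid(True)
plt.show()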

Optional: Testing of Results¶

Note:
Feel free to skip this section, as it includes diagnostics of the model. However, if you are interested in understanding how to assess performance and in some visualizations of how the model behaves, we encourage you to keep reading!

Testing the encoder: F1 on train and test data¶

The very low loss of the encoding model implies high accuracy on training episodes (the model is getting most pixels right during training). We can check the accuracy of the encoder directly.
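
As a reminder, the (pixel-wise, binary) F1 score used below is the harmonic mean of precision and recall, which in terms of true positives, false positives, and false negatives is:

$$ F_1 = \frac{2\,TP}{2\,TP + FP + FN}. $$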

In [ ]:
def _to_hw_mask(mask, thr=0.5):
    """Convert mask to bool [H,W] given various possible input shapes."""
    # mask: [H,W] or [1,H,W] or [B,1,H,W] -> returns bool [H,W]
    if isinstance(mask, torch.Tensor):
        m = mask
    else:
        m = torch.tensor(mask)
    m = m.detach().cpu()
    while m.dim() > 2:
        m = m.squeeze(0)
    return m > thr


def _f1_from_masks(pred_bool, gt_bool, eps=1e-8):
    """Compute F1 score from predicted and ground truth boolean masks."""
    pred = pred_bool.flatten()
    gt = gt_bool.flatten()

    tp = (pred & gt).sum().item()
    fp = (pred & ~gt).sum().item()
    fn = (~pred & gt).sum().item()

    denom = 2 * tp + fp + fn
    if denom == 0:
        # both empty -> perfect
        return 1.0
    return (2 * tp) / (denom + eps)


def evaluate_f1_one_shot(
    encoder,
    support_dataset,
    query_dataset,
    device,
    support_idx=0,
    thr=0.5,
    max_queries=None,
    skip_support_in_train=True,
):
    """
    Evaluate 1-shot F1 on query_dataset using a single support from support_dataset.

    Args:
        encoder: trained Encoder model
        support_dataset: GenevaRooftopDataset for support selection
        query_dataset: GenevaRooftopDataset for query evaluation
        device: torch device
        support_idx: int, index of the support example in support_dataset
        thr: float, threshold for binarizing masks
        max_queries: int or None, max number of queries to evaluate
        skip_support_in_train: bool, if True and support_dataset is query_dataset,
                               skip the support example during evaluation
    Returns:
        mean_f1: float, mean F1 over evaluated queries
        micro_f1: float, micro-averaged F1 over evaluated queries
        n_evaluated: int, number of queries evaluated

    """
    encoder.eval()

    img_s, mask_s = support_dataset[support_idx]

    n = len(query_dataset) if max_queries is None else min(max_queries, len(query_dataset))

    f1_list = []
    TP = FP = FN = 0

    with torch.no_grad():
        for i in range(n):
            if skip_support_in_train and (query_dataset is support_dataset) and (i == support_idx):
                continue

            img_q, mask_q = query_dataset[i]

            # your existing predictor: returns [1,2,H,W]
            logits = one_shot_predict(encoder, img_s, mask_s, img_q)
            pred = logits.argmax(dim=1)[0]  # [H,W] values 0/1

            pred_bool = (pred > 0).cpu().bool()
            gt_bool = _to_hw_mask(mask_q, thr=thr).bool()

            f1_list.append(_f1_from_masks(pred_bool, gt_bool))

            p = pred_bool.flatten()
            g = gt_bool.flatten()
            TP += (p & g).sum().item()
            FP += (p & ~g).sum().item()
            FN += (~p & g).sum().item()

    mean_f1 = sum(f1_list) / max(1, len(f1_list))
    micro_f1 = (2 * TP) / (2 * TP + FP + FN + 1e-8) if (2 * TP + FP + FN) > 0 else 1.0
    return mean_f1, micro_f1, len(f1_list)


def evaluate_f1_multiple_supports(
    encoder,
    train_base,
    test_base,
    device,
    K=1,
    seed=42,
    thr=0.5,
    max_queries=None,
):
    """
    Evaluate 1-shot F1 over multiple random supports from train_base.

    Args:
        encoder: trained Encoder model
        train_base: GenevaRooftopDataset for support selection
        test_base: GenevaRooftopDataset for query evaluation
        device: torch device
        K: int, number of random supports to evaluate
        seed: int, random seed for support selection
        thr: float, threshold for binarizing masks
        max_queries: int or None, max number of queries to evaluate per support

    Returns:
        support_idxs: list of selected support indices
        train_scores: list of (mean_f1, micro_f1) on train_base per support
        test_scores: list of (mean_f1, micro_f1) on test_base per support

    """
    rng = np.random.default_rng(seed)
    support_idxs = rng.choice(len(train_base), size=min(K, len(train_base)), replace=False).tolist()

    train_scores = []
    test_scores = []

    for sidx in support_idxs:
        tr_mean, tr_micro, _ = evaluate_f1_one_shot(
            encoder,
            support_dataset=train_base,
            query_dataset=train_base,
            device=device,
            support_idx=sidx,
            thr=thr,
            max_queries=max_queries,
            skip_support_in_train=True,
        )
        te_mean, te_micro, _ = evaluate_f1_one_shot(
            encoder,
            support_dataset=train_base,
            query_dataset=test_base,
            device=device,
            support_idx=sidx,
            thr=thr,
            max_queries=max_queries,
            skip_support_in_train=False,
        )
        train_scores.append((tr_mean, tr_micro))
        test_scores.append((te_mean, te_micro))

    train_mean_f1 = float(np.mean([x[0] for x in train_scores]))
    train_micro_f1 = float(np.mean([x[1] for x in train_scores]))
    test_mean_f1 = float(np.mean([x[0] for x in test_scores]))
    test_micro_f1 = float(np.mean([x[1] for x in test_scores]))

    print("Support indices:", support_idxs)
    print(f"Train (avg over supports): mean F1={train_mean_f1:.4f} | micro F1={train_micro_f1:.4f}")
    print(f"Test  (avg over supports): mean F1={test_mean_f1:.4f} | micro F1={test_micro_f1:.4f}")

    return support_idxs, train_scores, test_scores
In [107]:
support_idxs, train_scores, test_scores = evaluate_f1_multiple_supports(
    encoder=encoder,
    train_base=train_base,
    test_base=test_base,
    device=device,
    K=1,
    seed=42,
    thr=0.5,
    max_queries=50,
)

print("support_idxs:", support_idxs)
Support indices: [37]
Train (avg over supports): mean F1=0.8763 | micro F1=0.8864
Test  (avg over supports): mean F1=0.6583 | micro F1=0.6770
support_idxs: [37]

Comment:
The encoder achieves strong performance on training data (F1 ≈ 0.88) but shows a notable drop on the test set (F1 ≈ 0.66), indicating some degree of overfitting to the training distribution and therefore difficulty generalizing. This performance gap reflects the geographic domain shift between the training regions (northern and southern Geneva) and the test region (central Geneva), suggesting that rooftop patterns vary across different parts of the city. The model may also be too simple to transfer the quality of learning achieved on the training set to the test set.

Testing the encoder: IoU on train and test data¶

In [100]:
train_ious = evaluate_kshot_iou(encoder, train_base, train_base, K=1, num_samples=None)
test_ious = evaluate_kshot_iou(encoder, test_base, train_base, K=1, num_samples=None)
1-shot mean IoU over 423 test samples (unique=423, test_size=423): 0.738 (±0.161)
1-shot mean IoU over 102 test samples (unique=102, test_size=102): 0.449 (±0.204)

Comment:
The one-shot IoU drops markedly from ~0.74 on training samples to ~0.45 on test samples, confirming that the model struggles to generalize across regions. While the test IoU is still reasonable for few-shot segmentation, the gap suggests that the encoder learns region-specific features rather than fully invariant rooftop representations.
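
For reference, the IoU reported here is computed from boolean masks in the same spirit as the F1 helpers above. The minimal sketch below shows the metric itself; it is not the `evaluate_kshot_iou` implementation, which additionally handles support selection and sampling:

In [ ]:
def _iou_from_masks(pred_bool, gt_bool, eps=1e-8):
    """Compute IoU (intersection over union) from boolean prediction and GT masks."""
    pred = pred_bool.flatten()
    gt = gt_bool.flatten()
    intersection = (pred & gt).sum().item()
    union = (pred | gt).sum().item()
    if union == 0:
        # both masks empty -> treat as perfect agreement
        return 1.0
    return intersection / (union + eps)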

Visualizing encoder’s pixel embeddings¶

In [ ]:
def _collect_pixel_embeddings(encoder, dataset, device, n_images=None, thr=0.5):
    """
    Collect foreground and background pixel embeddings from a dataset.

    Args:
        encoder: Encoder model
        dataset: GenevaRooftopDataset instance
        device: torch device
        n_images: int or None, number of images to process (None = all)
        thr: float, threshold to separate FG/BG in mask

    Returns:
      fg_all: numpy array [N_fg, C]
      bg_all: numpy array [N_bg, C]

    """
    encoder.eval()
    fg_embeddings, bg_embeddings = [], []

    n = len(dataset) if (n_images is None) else min(n_images, len(dataset))

    with torch.no_grad():
        for i in range(n):
            img, mask = dataset[i]

            img = img.to(device)
            mask = mask.to(device)

            feat = encoder(img.unsqueeze(0))  # [1, C, h, w]
            _, C, h, w = feat.shape

            # mask -> [1,1,H,W] then downsample to [1,1,h,w]
            if mask.dim() == 2:
                mask_ = mask.unsqueeze(0).unsqueeze(0)
            elif mask.dim() == 3:
                mask_ = mask.unsqueeze(0)
            else:
                mask_ = mask
            mask_ = mask_.float()
            mask_small = F.interpolate(mask_, size=(h, w), mode="nearest")  # [1,1,h,w]

            # flatten
            feat_flat = feat.squeeze(0).view(C, -1).t()  # [h*w, C]
            mask_flat = mask_small.squeeze(0).squeeze(0).view(-1)  # [h*w]

            fg = mask_flat > thr
            bg = ~fg

            if fg.any():
                fg_embeddings.append(feat_flat[fg].cpu())
            if bg.any():
                bg_embeddings.append(feat_flat[bg].cpu())

    if len(fg_embeddings) == 0 or len(bg_embeddings) == 0:
        raise ValueError("No FG or BG pixels collected. Check masks/threshold.")

    fg_all = torch.cat(fg_embeddings, dim=0).numpy()
    bg_all = torch.cat(bg_embeddings, dim=0).numpy()
    return fg_all, bg_all


train_fg, train_bg = _collect_pixel_embeddings(encoder, train_base, device, n_images=None)
test_fg, test_bg = _collect_pixel_embeddings(encoder, test_base, device, n_images=None)
In [ ]:
from sklearn.manifold import TSNE

def tsne_train_test(train_fg, train_bg, test_fg, test_bg, k=500, seed=42, perplexity=None):
    """
    Perform t-SNE on sampled pixel embeddings from train/test and FG/BG.

    Args:
        train_fg: numpy array [N_train_fg, C]
        train_bg: numpy array [N_train_bg, C]
        test_fg: numpy array [N_test_fg, C]
        test_bg: numpy array [N_test_bg, C]
        k: int, number of points to sample from each category
        seed: int, random seed
        perplexity: int or None, t-SNE perplexity (None = auto choose)

    Returns:
        None (plots t-SNE)

    """
    rng = np.random.default_rng(seed)

    def sample(arr, k):
        k = min(k, len(arr))
        idx = rng.choice(len(arr), k, replace=False)
        return arr[idx]

    # sample points
    tr_fg = sample(train_fg, k)
    tr_bg = sample(train_bg, k)
    te_fg = sample(test_fg, k)
    te_bg = sample(test_bg, k)

    X = np.vstack([tr_fg, tr_bg, te_fg, te_bg])
    y = np.array([1] * len(tr_fg) + [0] * len(tr_bg) + [1] * len(te_fg) + [0] * len(te_bg))  # 1=roof, 0=bg
    d = np.array([0] * (len(tr_fg) + len(tr_bg)) + [1] * (len(te_fg) + len(te_bg)))  # 0=train, 1=test

    n = len(X)
    print(
        f"Using points: train_fg={len(tr_fg)}, train_bg={len(tr_bg)}, test_fg={len(te_fg)}, test_bg={len(te_bg)} | total={n}"
    )

    # choose a safe perplexity
    if perplexity is None:
        perplexity = max(2, min(10, (n - 1) // 3))  # good default for small n
    perplexity = min(perplexity, n - 1)
    if perplexity < 2:
        raise ValueError(f"Not enough points for t-SNE (total={n}). Use PCA or collect more points.")

    print("perplexity =", perplexity)

    Z = TSNE(
        n_components=2,
        perplexity=perplexity,
        random_state=seed,
        init="pca",
        learning_rate="auto",
    ).fit_transform(X)

    plt.figure(figsize=(8, 6))
    # Background
    plt.scatter(
        Z[(y == 0) & (d == 0), 0],
        Z[(y == 0) & (d == 0), 1],
        s=12,
        alpha=0.6,
        label="Background train",
        marker="o",
        c="green",
    )
    plt.scatter(
        Z[(y == 0) & (d == 1), 0],
        Z[(y == 0) & (d == 1), 1],
        s=20,
        alpha=0.6,
        label="Background test",
        marker="x",
        c="green",
    )
    # Rooftop / Foreground
    plt.scatter(
        Z[(y == 1) & (d == 0), 0],
        Z[(y == 1) & (d == 0), 1],
        s=12,
        alpha=0.6,
        label="Rooftop train",
        marker="o",
        c="orange",
    )
    plt.scatter(
        Z[(y == 1) & (d == 1), 0],
        Z[(y == 1) & (d == 1), 1],
        s=20,
        alpha=0.6,
        label="Rooftop test",
        marker="x",
        c="orange",
    )

    plt.legend()
    plt.title("t-SNE fit (train vs test)")
    plt.xticks([])
    plt.yticks([])
    plt.tight_layout()
    plt.show()
In [111]:
tsne_train_test(train_fg, train_bg, test_fg, test_bg, k=1000, seed=42)
Using points: train_fg=1000, train_bg=1000, test_fg=1000, test_bg=1000 | total=4000
perplexity = 10
[Figure: t-SNE of sampled pixel embeddings, colored by class (rooftop vs. background) and marked by split (train vs. test)]

Comment:
The t-SNE visualization reduces the high-dimensional pixel embeddings to 2D, revealing the structure of the learned feature space. The clear separation between rooftop (orange) and background (green) clusters demonstrates that the encoder successfully learns discriminative features for semantic segmentation. The mixing of train and test samples within each cluster (circles vs. X markers) is encouraging, as it suggests the prototype-based classification can work across geographic domains. However, some separation between train and test distributions is still visible, which partially explains the performance gap observed in the testing results.

Which cases are the easiest and hardest, and why?¶

We analyze the 5-shot learning case.

In [ ]:
def _to_bool_mask(mask, thr=0.5):
    """Convert mask to boolean [H,W] tensor."""
    m = mask.detach().cpu()
    while m.dim() > 2:
        m = m.squeeze(0)
    return m > thr


def _f1_from_masks_5_shot(pred_bool, gt_bool, eps=1e-8):
    """Compute F1 score from boolean prediction and ground truth masks."""
    pred = pred_bool.flatten()
    gt = gt_bool.flatten()
    tp = (pred & gt).sum().item()
    fp = (pred & ~gt).sum().item()
    fn = (~pred & gt).sum().item()
    denom = 2 * tp + fp + fn
    return 1.0 if denom == 0 else (2 * tp) / (denom + eps)


def visualize_cases_kshot(
    encoder,
    support_dataset,
    query_dataset,
    support_indices,
    device,
    mode="hard",  # "hard" or "easy"
    n_cases=5,
    max_queries=None,
    thr=0.5,
):
    """
    Visualize hardest/easiest cases in query_dataset given K-shot support.

    Args:
        encoder: trained Encoder model
        support_dataset: GenevaRooftopDataset for support selection
        query_dataset: GenevaRooftopDataset for query evaluation
        support_indices: list of int, indices of support examples in support_dataset
        device: torch device
        mode: str, "hard" or "easy" to select hardest/easiest cases by F1
        n_cases: int, number of cases to visualize
        max_queries: int or None, max number of queries to evaluate
        thr: float, threshold for binarizing masks

    Returns:
        None (plots the selected cases)

    """
    assert mode in ["hard", "easy"]

    encoder.eval()

    # stack supports (what your k_shot_predict expects)
    support_imgs = torch.stack([support_dataset[i][0] for i in support_indices], dim=0)  # [K,C,H,W]
    support_masks = torch.stack([support_dataset[i][1] for i in support_indices], dim=0)  # [K,H,W] or [K,1,H,W]
    if support_masks.dim() == 3:
        support_masks = support_masks.unsqueeze(1)  # [K,1,H,W]

    n = len(query_dataset) if max_queries is None else min(max_queries, len(query_dataset))

    rows = []
    with torch.no_grad():
        for qidx in range(n):
            img_q, mask_q = query_dataset[qidx]

            logits = k_shot_predict(encoder, support_imgs, support_masks, img_q)  # [1,2,H,W]
            if logits.dim() == 3:
                logits = logits.unsqueeze(0)
            logits = logits.cpu()

            prob_fg = F.softmax(logits, dim=1)[0, 1]  # [H,W]
            pred = logits.argmax(dim=1)[0].bool()  # [H,W]
            gt = _to_bool_mask(mask_q, thr=thr).bool()

            f1 = _f1_from_masks_5_shot(pred, gt)

            # uncertainty: mean entropy
            p = prob_fg.clamp(1e-8, 1 - 1e-8)
            entropy = -(p * torch.log(p) + (1 - p) * torch.log(1 - p))
            rows.append((qidx, f1, float(entropy.mean().item())))

    # sort by F1
    rows.sort(key=lambda x: x[1], reverse=(mode == "easy"))
    picked = rows[: min(n_cases, len(rows))]

    print(f"{'Easiest' if mode=='easy' else 'Hardest'} cases by F1 (K={len(support_indices)}):")
    for qidx, f1, _ent in picked:
        print(f"  idx={qidx:4d} | F1={f1:.3f}")

    # visualize picked
    fig, axes = plt.subplots(len(picked), 6, figsize=(24, 4 * len(picked)))
    if len(picked) == 1:
        axes = np.array([axes])

    with torch.no_grad():
        for r, (qidx, f1, _ent_mean) in enumerate(picked):
            img_q, mask_q = query_dataset[qidx]
            gt = _to_bool_mask(mask_q, thr=thr).bool()

            logits = k_shot_predict(encoder, support_imgs, support_masks, img_q)
            if logits.dim() == 3:
                logits = logits.unsqueeze(0)
            logits = logits.cpu()

            prob_fg = F.softmax(logits, dim=1)[0, 1]
            margin = logits[0, 1] - logits[0, 0]
            pred = logits.argmax(dim=1)[0].bool()

            # error overlay: FP=+1 (red), FN=-1 (blue)
            fp = (pred & ~gt).float()
            fn = (~pred & gt).float()
            err = (fp - fn).numpy()

            # uncertainty map: entropy
            p = prob_fg.clamp(1e-8, 1 - 1e-8)
            entropy = (-(p * torch.log(p) + (1 - p) * torch.log(1 - p))).numpy()

            # margin scaling for display
            margin_np = margin.numpy()
            m = np.percentile(np.abs(margin_np), 99) + 1e-8

            axes[r, 0].imshow(tensor_to_rgb(img_q))
            axes[r, 0].set_title(f"RGB idx={qidx}")
            axes[r, 0].axis("off")
            axes[r, 1].imshow(gt.numpy(), cmap="gray")
            axes[r, 1].set_title("GT")
            axes[r, 1].axis("off")
            axes[r, 2].imshow(pred.numpy(), cmap="gray")
            axes[r, 2].set_title(f"Pred (F1={f1:.3f})")
            axes[r, 2].axis("off")

            axes[r, 3].imshow(tensor_to_rgb(img_q))
            axes[r, 3].imshow(err, cmap="bwr", alpha=0.6, vmin=-1, vmax=1)
            axes[r, 3].set_title("Error (red=FP, blue=FN)")
            axes[r, 3].axis("off")

            axes[r, 4].imshow(tensor_to_rgb(img_q))
            axes[r, 4].imshow(entropy, cmap="hot", alpha=0.6)
            axes[r, 4].set_title("Uncertainty (entropy)")
            axes[r, 4].axis("off")

            axes[r, 5].imshow(tensor_to_rgb(img_q))
            axes[r, 5].imshow(margin_np, cmap="bwr", alpha=0.6, vmin=-m, vmax=m)
            axes[r, 5].set_title("Evidence (FG−BG margin)")
            axes[r, 5].axis("off")

    plt.tight_layout()
    plt.show()
    return picked
In [117]:
rng = np.random.default_rng(42)
support_indices = rng.choice(len(train_base), size=5, replace=False).tolist()

easy = visualize_cases_kshot(encoder, train_base, test_base, support_indices, device, mode="easy", n_cases=5)
Easiest cases by F1 (K=5):
  idx=  75 | F1=0.893
  idx=  23 | F1=0.874
  idx=  85 | F1=0.869
  idx=   8 | F1=0.860
  idx=   4 | F1=0.858
[Figure: five easiest test cases — RGB, ground truth, prediction, error overlay, entropy, and FG−BG margin]
In [57]:
rng = np.random.default_rng(42)
support_indices = rng.choice(len(train_base), size=5, replace=False).tolist()

hard = visualize_cases_kshot(encoder, train_base, test_base, support_indices, device, mode="hard", n_cases=5)
Hardest cases by F1 (K=5):
  idx= 101 | F1=0.004
  idx=  33 | F1=0.020
  idx=  51 | F1=0.101
  idx=  30 | F1=0.111
  idx=  65 | F1=0.171
[Figure: five hardest test cases — RGB, ground truth, prediction, error overlay, entropy, and FG−BG margin]

Comment:
The visualization reveals distinct patterns between easy and hard cases. Easy cases (F1 > 0.85) typically feature well-defined rooftop boundaries, consistent lighting, and rooftop types similar to the training distribution, resulting in low-entropy (high-confidence) predictions. In contrast, hard cases (F1 < 0.2) often fail badly, and the worst (F1 < 0.05) fail almost completely, with the error maps (column 4) showing either massive false positives (red) or false negatives (blue). Common failure modes include:

(1) rooftops with unusual materials or shapes not seen in training,
(2) heavy shadows or occlusions obscuring roof boundaries, or scenes where rooftops are hard to distinguish from other structures such as parking lots,
(3) complex multi-level structures for which the support examples lack similar architectural patterns, and
(4) tiles with very small PV-compatible rooftop areas, where the model struggles to build reliable prototypes from limited support pixels.

The variable entropy values in hard cases suggest that model uncertainty alone is not a reliable indicator of failure, as some catastrophic failures show low entropy, indicating overconfident but incorrect predictions.

Discussion of Results¶

To summarise, in this tutorial, we monitored two key indicators: the meta-training loss and the IoU of the predicted masks.

1. Meta-Training Loss¶

Across epochs, this loss decreases steadily, indicating that:

  • the encoder learns a progressively more meaningful feature representation
  • rooftop vs. non-rooftop pixels become more separable in feature space
  • prototype-based segmentation improves throughout meta-training

Note that the masks in this dataset describe PV-exploitable rooftop areas, not general rooftop areas. This is a considerably harder task than plain rooftop segmentation. Even though the model cannot segment these areas very precisely with such limited supervision, it still internalizes rooftop characteristics and reduces prediction errors effectively.

2. Predicted Masks & Quantitative Performance¶

To evaluate the model, we report few-shot segmentation results across different numbers of labeled support examples:

  • 1-shot: 1 labeled support example per test episode
  • 5-shot: 5 labeled support examples
  • 10-shot: 10 labeled support examples
  • 20-shot: 20 labeled support examples

On a set of 102 test tiles drawn from a geographically distinct region of Geneva, the model achieves a mean IoU between roughly 0.45 and 0.49 across the 1- to 20-shot settings.
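
To visualize this trend, a short sketch (assuming the `ious_0shot` … `ious_20shot` variables computed earlier are sequences of per-tile IoU values) plots the mean IoU against the number of shots:

In [ ]:
import numpy as np
import matplotlib.pyplot as plt

# Assumes ious_0shot ... ious_20shot hold per-tile IoU values from evaluate_kshot_iou
ks = [0, 1, 5, 10, 20]
iou_lists = [ious_0shot, ious_1shot, ious_5shot, ious_10shot, ious_20shot]
means = [float(np.mean(v)) for v in iou_lists]
stds = [float(np.std(v)) for v in iou_lists]

plt.figure(figsize=(6, 4))
plt.errorbar(ks, means, yerr=stds, marker="o", capsize=3)
plt.xlabel("Number of support examples (K)")
plt.ylabel("Mean IoU on 102 test tiles")
plt.title("Few-shot IoU vs. number of shots")
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()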

While modest, these results are encouraging given:

  • the strong label constraints
  • the complexity of urban rooftop structures
  • the geographic domain shift between training and testing

A qualitative example highlights this behavior:

  • Support images & masks define the relevant rooftop characteristics
  • On the query image, the predicted mask captures major rooftop shapes
  • Large rooftop surfaces are generally identified
  • Fine-grained details remain imperfect, but the model generalizes to textures and geometries not seen during training

These results suggest that the Prototypical Network successfully learns a useful and transferable feature embedding.


Discussion of Potential Extensions¶

The results show that Prototypical Networks can learn meaningful rooftop representations even with very limited supervision. However, there remains substantial room to improve performance and explore alternative design choices within the few-shot learning setup. Several aspects of the training procedure could be refined to enhance segmentation accuracy:

Model Tuning and Regularization¶

Incorporating techniques such as weight decay, dropout or early stopping could stabilize feature learning and reduce overfitting to the small support sets typically used in few-shot learning.
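
As a minimal sketch of these ideas (the optimizer settings and the convolutional block below are illustrative assumptions, not the exact configuration used earlier), weight decay is a single optimizer argument and dropout a single extra layer:

In [ ]:
import torch
import torch.nn as nn

# Weight decay: L2 regularization added directly on the optimizer (values are placeholders)
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3, weight_decay=1e-4)

# Dropout: a 2D dropout layer inserted between convolutional blocks of the encoder
regularized_block = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    nn.Dropout2d(p=0.1),  # randomly zeroes entire feature maps during training
)

# Early stopping: keep the encoder weights from the epoch with the best validation IoU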

Training for More Epochs¶

For demonstration purposes, the model was trained for only a limited number of epochs. Extending training duration or increasing the number of sampled episodes per epoch could help the encoder converge toward a more discriminative embedding space, potentially improving segmentation performance.

Extending the Task Toward Policy Relevance¶

A rough solar potential approximation could be built on top of the segmentation task, for example by combining the predicted rooftop area with IoU-based uncertainty estimates. This could provide a first-order indicator of solar suitability, connecting the model's outputs to real-world energy planning applications.
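
A minimal sketch of such a first-order estimate is given below. The pixel size, usable-area fraction, and module power density are placeholder assumptions rather than values taken from the Geneva data, so treat the output as an order-of-magnitude indicator only:

In [ ]:
# All constants below are illustrative assumptions
PIXEL_SIZE_M = 0.25      # ground sampling distance in metres per pixel
USABLE_FRACTION = 0.7    # fraction of the detected roof area actually usable for PV
W_PEAK_PER_M2 = 180.0    # typical module power density in W per square metre

def estimate_pv_potential_kwp(pred_mask_bool):
    """Convert a boolean rooftop mask into a rough PV capacity estimate in kWp."""
    n_pixels = int(pred_mask_bool.sum())
    roof_area_m2 = n_pixels * PIXEL_SIZE_M ** 2
    return roof_area_m2 * USABLE_FRACTION * W_PEAK_PER_M2 / 1000.0

# Example: estimate_pv_potential_kwp(pred_bool) for one test tile's predicted mask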

Trying Different Encoder Backbones¶

The current prototype uses a lightweight CNN encoder for simplicity. Replacing it with stronger architectures such as ResNet-50 or a Vision Transformer (ViT) may yield more robust and generalizable feature representations.
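
As a sketch of what such a swap could look like (assuming a recent torchvision; projecting to the same embedding dimension as the lightweight encoder is our own assumption), a ResNet-50 trunk can be wrapped into a drop-in encoder:

In [ ]:
import torch.nn as nn
from torchvision.models import resnet50

class ResNetEncoder(nn.Module):
    """Hypothetical drop-in encoder: ResNet-50 trunk plus a 1x1 projection."""

    def __init__(self, out_dim=64, pretrained=False):
        super().__init__()
        backbone = resnet50(weights="DEFAULT" if pretrained else None)
        # keep all layers up to the last conv stage; output is [B, 2048, H/32, W/32]
        self.trunk = nn.Sequential(*list(backbone.children())[:-2])
        self.project = nn.Conv2d(2048, out_dim, kernel_size=1)

    def forward(self, x):
        # note: the feature map is coarser (stride 32) than the lightweight encoder's
        return self.project(self.trunk(x))

The coarser feature map means masks and logits need to be resized accordingly, which the interpolation steps already used in the helper functions above should largely handle.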

Using More Complex Prototypical Network Architectures or More Advanced FSL Architectures¶

We evaluate a "vanilla" ProtNet, and this is a first implementation of such. The performances can be enhanced by using more complex architectures. Beyond a vanilla Prototypical Network, you can try relation-based models like Relation Networks, which learn a small neural “comparator” between query features and support prototypes instead of using a fixed distance metric—often improving robustness under domain shift. You can also use gradient-based meta-learning methods such as MAML/Reptile, which explicitly optimize the encoder to adapt quickly to a new task with a few support examples. Advanced architectures like transformer-based matching (cross-attention between support and query features) can model richer support–query interactions than a single prototype per class. This tutorial is hopefully a first step, and might pave the way to learn more!


Limitations¶

While this tutorial successfully demonstrates the core ideas behind few-shot segmentation with Prototypical Networks, several important simplifications limit its applicability. Many of these choices were made intentionally to ensure the tutorial remains computationally lightweight and easy to reproduce.

Simplified Experimental Setup¶

To keep the workflow accessible, we used:

  • a very small training set, both in the number of tiles and support examples per episode
  • a lightweight encoder, rather than higher-capacity backbones common in remote sensing (e.g., ResNet-50, Swin Transformer)
  • a short training schedule, with few epochs and limited episode sampling

These design choices improve reproducibility but also restrict the achievable segmentation performance.
In practical applications—such as large-scale rooftop or solar mapping—substantially more data, stronger feature extractors, and longer training would be necessary.

Modelling Choices Intentionally Kept Simple¶

Several simplifications reduce the robustness of the resulting predictions:

  • Binary segmentation (roof vs. non-roof) ignores roof type, material, shadows, and occlusions—all of which matter for accurate solar potential estimation.
  • No post-processing was applied (e.g., morphological filters, CRFs), even though such steps typically improve mask quality; a minimal example is sketched below.
  • No uncertainty estimation was included, despite being crucial for planning and policy-relevant applications.

These omissions help focus on core concepts but limit real-world applicability.
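
As a concrete example of the post-processing mentioned in the list above, a simple morphological cleanup (a sketch assuming the predicted mask is a boolean NumPy array) removes isolated false-positive pixels and fills small holes:

In [ ]:
from scipy import ndimage

def clean_mask(pred_bool_np, iterations=1):
    """Morphological opening then closing to remove speckle and fill small holes."""
    opened = ndimage.binary_opening(pred_bool_np, iterations=iterations)
    closed = ndimage.binary_closing(opened, iterations=iterations)
    return closed

# Example: cleaned = clean_mask(pred.numpy()) before computing IoU/F1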

Dataset Biases and Generalization Limits¶

The dataset itself introduces structural biases:

  • Imagery is taken exclusively from Geneva, a wealthy European city with relatively homogeneous architectural styles.
  • Rooftop morphology varies globally—informal housing, climate-adapted roof shapes, and diverse materials are not represented here.
  • The geographic split (a denser center vs. a sparser periphery) creates a stylized domain shift but does not reflect true global variation.

If applied uncritically in policy contexts, these limitations could reinforce geographic inequities, for example by overestimating solar potential in well-represented neighborhoods and underestimating it in underrepresented ones.

Addressing These Challenges¶

To improve real-world deployment, data analysts and practitioners should consider:

  • Expanding dataset diversity (more cities, varied roof types, lighting conditions, seasons).
  • Evaluating fairness and generalization across socioeconomic and geographic groups.
  • Incorporating uncertainty estimation, especially when predictions support infrastructure or planning decisions.
  • Validating model outputs with domain experts (urban planners, energy modelers, local authorities).

By recognizing these limitations, we can better understand the conditions under which few-shot rooftop segmentation performs well—and the steps required to make such models reliable for operational or policy-driven use.


⭐ Challenge ⭐¶

Photovoltaic Capacity Estimation in France via Prototypical Network Segmentation¶

Overview¶

In this challenge, you can test your new knowledge and apply your skills to estimate PV capacities in France. You will use a dataset containing RGB aerial imagery of buildings and landscapes, polygon segmentation masks of photovoltaic (PV) arrays, and installation metadata (location and nominal capacity) for thousands of solar installations across France. The dataset can be accessed here.

Your objectives are to:

  1. Segment PV arrays in unseen aerial RGB images using the provided segmentation masks
    (i.e., generate per-pixel segmentation masks of solar panels).

  2. Compare segmentation results with installation metadata (metadata.csv), which includes ground-truth installation characteristics.

  3. Evaluate:

    • (a) your practical application of Prototypical Network–based segmentation,
    • (b) the model’s ability to generalize to a new dataset,
    • (c) the accuracy of area-based PV capacity estimates when compared against reported capacities.
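
As a rough illustration of objective 3(c), the sketch below converts predicted PV areas into capacity estimates and compares them with the reported values. The column names and the power-density constant are assumptions, not the actual schema of metadata.csv, so adapt them to the file you download:

In [ ]:
import pandas as pd

W_PEAK_PER_M2 = 180.0  # assumed module power density (W per square metre)

def compare_capacities(pred_areas_m2, metadata_csv):
    """Join predicted PV areas (dict: installation id -> m^2) with reported capacities."""
    meta = pd.read_csv(metadata_csv)  # hypothetical columns: 'id', 'nominal_capacity_kwp'
    df = pd.DataFrame({"id": list(pred_areas_m2.keys()),
                       "pred_area_m2": list(pred_areas_m2.values())})
    df["est_capacity_kwp"] = df["pred_area_m2"] * W_PEAK_PER_M2 / 1000.0
    df = df.merge(meta[["id", "nominal_capacity_kwp"]], on="id", how="inner")
    df["abs_error_kwp"] = (df["est_capacity_kwp"] - df["nominal_capacity_kwp"]).abs()
    print(f"MAE: {df['abs_error_kwp'].mean():.2f} kWp over {len(df)} installations")
    return df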

You can use the following resources to help you:

  • Published Paper:

    Nature Article

  • GitHub Repository (tools & implementation guidance):

    BDAP-PV Repository


Further Resources¶

Foundational Papers¶

  • Shaban, A., Bansal, S., Liu, Z., Essa, I., & Boots, B. (2017). One-Shot Learning for Semantic Segmentation. BMVC.
    https://doi.org/10.48550/arXiv.1709.03410
    → First paper to formally define the few-shot semantic segmentation task.
    → Introduces a conditioning branch that generates segmentation parameters from support examples.

  • Snell, J., Swersky, K., & Zemel, R. (2017). Prototypical Networks for Few-Shot Learning. Advances in Neural Information Processing Systems (NeurIPS 2017), Vol. 30.
    https://doi.org/10.48550/arXiv.1703.05175
    → Establishes the concept of class prototypes derived from support embeddings.
    → Influences nearly all metric-based few-shot segmentation methods.

  • Wang, K., Liew, J. H., Zou, Y., Zhou, D., & Feng, J. (2019). PANet: Few-Shot Image Semantic Segmentation with Prototype Alignment. ICCV.
    https://doi.org/10.48550/arXiv.1908.06391
    → Introduces prototype alignment networks to better utilize support-set information.
    → Serves as a strong and efficient baseline for few-shot segmentation.

  • Tian, Z., Zhao, H., Shu, M., Yang, Z., Li, R., & Jia, J. (2020). Prior Guided Feature Enrichment Network for Few-Shot Segmentation. IEEE TPAMI, 44(2), 1050–1065.
    https://doi.org/10.1109/TPAMI.2020.3013717
    → Introduces PFENet, significantly improving generalization.
    → Set a new benchmark after PANet for high-performance few-shot segmentation.

Other Resources: GitHub Repos and YT Videos¶

  • Code of the original implementation of Prototypical Networks (Snell et al., 2017)
    https://github.com/jakesnell/prototypical-networks

  • Video explanation of Prototypical Networks (deep dive into concepts + code)
    https://www.youtube.com/watch?v=rHGPfl0pvLY

  • Comprehensive GitHub repository on many few-shot learning techniques
    https://github.com/sicara/easy-few-shot-learning?tab=readme-ov-file

  • Prototypical Networks for Few-Shot Learning
    https://dancsalo.github.io/2020/12/24/prototypical/

  • Meta-Learning: Learning to Learn Fast
    https://lilianweng.github.io/posts/2018-11-30-meta-learning/

  • Transformer-based state-of-the-art segmentation: YOLOE (Zero-shot detection & segmentation)
    https://github.com/THU-MIG/yoloe

  • Transformer-based state-of-the-art segmentation: SAM Models (Segment Anything Model — Meta AI)
    https://github.com/facebookresearch/segment-anything


References¶

  • Alsentzer, E., Li, M. M., Kobren, S. N., Noori, A., Undiagnosed Diseases Network, Kohane, I. S., & Zitnik, M. (2025). Few shot learning for phenotype-driven diagnosis of patients with rare genetic diseases. npj Digital Medicine, 8(1), 380. https://doi.org/10.1038/s41746-025-01749-1

  • Castello, R., Walch, A., Attias, R., Cadei, R., Jiang, S., & Scartezzini, J.-L. (2021). Quantification of the suitable rooftop area for solar panel installation from overhead imagery using convolutional neural networks. Journal of Physics: Conference Series, 2042(1), 012002. https://doi.org/10.1088/1742-6596/2042/1/012002

  • Chen, Y., Wei, C., Wang, D., Ji, C., & Li, B. (2022). Semi-supervised contrastive learning for few-shot segmentation of remote sensing images. Remote Sensing, 14(17), 4254. https://doi.org/10.3390/rs14174254

  • Ding, H., Zhang, H., & Jiang, X. (2022). Self-regularized prototypical network for few-shot semantic segmentation. Pattern Recognition, 132, 109018. https://doi.org/10.1016/j.patcog.2022.109018

  • Finn, C., Abbeel, P., & Levine, S. (2017). Model-agnostic meta-learning for fast adaptation of deep networks. In International Conference on Machine Learning (pp. 1126–1135). PMLR. https://doi.org/10.48550/arXiv.1703.03400

  • Ge, Z., Fan, X., Zhang, J., & Jin, S. (2025). SegPPD-FS: Segmenting plant pests and diseases in the wild using few-shot learning. Plant Phenomics, 100121. https://doi.org/10.1016/j.plaphe.2025.100121

  • Hu, Y., Liu, C., Li, Z., Xu, J., Han, Z., & Guo, J. (2022). Few-shot building footprint shape classification with relation network. ISPRS International Journal of Geo-Information, 11(5), 311. https://doi.org/10.3390/ijgi11050311

  • Jadon, S. (2021, February). COVID-19 detection from scarce chest x-ray image data using few-shot deep learning approach. In Medical Imaging 2021: Imaging Informatics for Healthcare, Research, and Applications (Vol. 11601, pp. 161–170). SPIE. https://doi.org/10.1117/12.2581496

  • Lee, G. Y., Dam, T., Ferdaus, M. M., Poenar, D. P., & Duong, V. (2025). Enhancing Few-Shot Classification of Benchmark and Disaster Imagery with ATTBHFA-Net. arXiv preprint arXiv:2510.18326. https://doi.org/10.48550/arXiv.2510.18326

  • Li, X., He, Z., Zhang, L., Guo, S., Hu, B., & Guo, K. (2025). CDCNet: Cross-domain few-shot learning with adaptive representation enhancement. Pattern Recognition, 162, 111382. https://doi.org/10.1016/j.patcog.2025.111382

  • Puthumanaillam, G., & Verma, U. (2023). Texture based prototypical network for few-shot semantic segmentation of forest cover: Generalizing for different geographical regions. Neurocomputing, 538, 126201. https://doi.org/10.1016/j.neucom.2023.03.062

  • Shaban, A., Bansal, S., Liu, Z., Essa, I., & Boots, B. (2017). One-shot learning for semantic segmentation. BMVC. https://doi.org/10.48550/arXiv.1709.03410

  • Snell, J., Swersky, K., & Zemel, R. (2017). Prototypical networks for few-shot learning. Advances in Neural Information Processing Systems (NeurIPS 2017), Vol. 30. https://doi.org/10.48550/arXiv.1703.05175

  • Sung, F., Yang, Y., Zhang, L., Xiang, T., Torr, P. H., & Hospedales, T. M. (2018). Learning to compare: Relation network for few-shot learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 1199–1208). https://doi.org/10.1109/CVPR.2018.00131

  • Tian, Z., Zhao, H., Shu, M., Yang, Z., Li, R., & Jia, J. (2020). Prior guided feature enrichment network for few-shot segmentation. IEEE TPAMI, 44(2), 1050–1065. https://doi.org/10.1109/TPAMI.2020.3013717

  • Wang, K., Liew, J. H., Zou, Y., Zhou, D., & Feng, J. (2019). PANet: Few-shot image semantic segmentation with prototype alignment. ICCV. https://doi.org/10.48550/arXiv.1908.06391